{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练词向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gensim.models import Word2Vec\n",
    "import pandas as pd\n",
    "import jieba\n",
    "import pickle\n",
    "\n",
    "# 设置待训练的Embedding dims\n",
    "embed_size = 300"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Char级别Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = pd.read_csv(\"../jupyter/shuffle-data/train_data.csv\")\n",
    "dev_data = pd.read_csv(\"../jupyter/shuffle-data/dev_data.csv\")\n",
    "\n",
    "# Build the character-level word2vec corpus: each value in columns 1 and 2\n",
    "# of train_data (presumably the two query texts -- confirm against the CSV)\n",
    "# becomes one sentence, represented as the list of its characters.\n",
    "# Duplicate sentences are dropped and first-seen order is kept. The set of\n",
    "# already-seen sentences makes each membership check O(1) instead of\n",
    "# rescanning the whole corpus list for every query.\n",
    "char_train_list = []\n",
    "seen_char_lines = set()\n",
    "# The columns are walked by position, so this does not rely on the index\n",
    "# labels coinciding with row positions.\n",
    "for query_1, query_2 in zip(train_data.iloc[:, 1], train_data.iloc[:, 2]):\n",
    "    for q in (query_1, query_2):\n",
    "        line = list(q)\n",
    "        key = tuple(line)  # lists are unhashable; the tuple is the set key\n",
    "        if key not in seen_char_lines:\n",
    "            seen_char_lines.add(key)\n",
    "            char_train_list.append(line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['剧', '烈', '运', '动', '后', '咯', '血', '，', '是', '怎', '么', '了', '？'],\n",
       " ['剧', '烈', '运', '动', '后', '咯', '血', '是', '什', '么', '原', '因', '？'],\n",
       " ['剧', '烈', '运', '动', '后', '为', '什', '么', '会', '咯', '血', '？'],\n",
       " ['剧', '烈', '运', '动', '后', '咯', '血', '，', '应', '该', '怎', '么', '处', '理', '？'],\n",
       " ['剧', '烈', '运', '动', '后', '咯', '血', '，', '需', '要', '就', '医', '吗', '？']]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "char_train_list[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a character-level word2vec model on char_train_list and pickle it to\n",
    "# w2v_char_<embed_size>.pkl in the current working directory.\n",
    "# Note: the name model is rebound later by the word-level training cell.\n",
    "model = Word2Vec(\n",
    "    char_train_list,\n",
    "    size=embed_size,\n",
    "    window=5,\n",
    "    min_count=1,\n",
    "    workers=4,\n",
    ")\n",
    "char_model_path = \"w2v_char_\" + str(embed_size) + \".pkl\"\n",
    "with open(char_model_path, \"wb\") as model_file:\n",
    "    pickle.dump(model, model_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Word级别Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /var/folders/j1/ls86yccj7l5dyscbpmp85ngw0000gn/T/jieba.cache\n",
      "Loading model cost 0.819 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[['剧烈运动', '后', '咯血', '，', '是', '怎么', '了', '？'],\n",
       " ['剧烈运动', '后', '咯血', '是', '什么', '原因', '？'],\n",
       " ['剧烈运动', '后', '为什么', '会', '咯血', '？'],\n",
       " ['剧烈运动', '后', '咯血', '，', '应该', '怎么', '处理', '？'],\n",
       " ['剧烈运动', '后', '咯血', '，', '需要', '就医', '吗', '？'],\n",
       " ['剧烈运动', '后', '咯血', '，', '是否', '很', '严重', '？'],\n",
       " ['百令', '胶囊', '需要', '注意', '什么', '？'],\n",
       " ['百令', '胶囊', '有', '什么', '注意事项', '？'],\n",
       " ['服用', '百令', '胶囊', '有', '什么', '需要', '特别', '注意', '的', '吗', '？'],\n",
       " ['百令', '胶囊', '如何', '服用', '？']]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the word-level word2vec corpus: each value in columns 1 and 2 of\n",
    "# train_data is segmented with jieba.lcut and becomes one sentence (a list\n",
    "# of tokens). Duplicate sentences are dropped and first-seen order is kept.\n",
    "# The set of already-seen sentences makes each membership check O(1)\n",
    "# instead of rescanning the whole corpus list for every query.\n",
    "word_train_list = []\n",
    "seen_word_lines = set()\n",
    "# The columns are walked by position, so this does not rely on the index\n",
    "# labels coinciding with row positions.\n",
    "for query_1, query_2 in zip(train_data.iloc[:, 1], train_data.iloc[:, 2]):\n",
    "    for q in (query_1, query_2):\n",
    "        line = jieba.lcut(q)  # already a list; no extra copy needed\n",
    "        key = tuple(line)  # lists are unhashable; the tuple is the set key\n",
    "        if key not in seen_word_lines:\n",
    "            seen_word_lines.add(key)\n",
    "            word_train_list.append(line)\n",
    "\n",
    "word_train_list[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a word-level word2vec model on word_train_list and pickle it to\n",
    "# w2v_word_<embed_size>.pkl in the current working directory.\n",
    "# Note: this rebinds model, replacing the character-level model in memory.\n",
    "model = Word2Vec(\n",
    "    word_train_list,\n",
    "    size=embed_size,\n",
    "    window=5,\n",
    "    min_count=1,\n",
    "    workers=4,\n",
    ")\n",
    "word_model_path = \"w2v_word_\" + str(embed_size) + \".pkl\"\n",
    "with open(word_model_path, \"wb\") as model_file:\n",
    "    pickle.dump(model, model_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from gensim.models import Word2Vec\n",
    "\n",
    "# char_model = Word2Vec.load(\"w2v_char_300.pkl\")\n",
    "# word_model = Word2Vec.load(\"w2v_word_300.pkl\")\n",
    "\n",
    "# word_model.wv.vocab\n",
    "# char_model.wv.vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 演练一下构建word2idx和对应emb_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token -> row-id lookup. Ids 0 and 1 are reserved for the padding and\n",
    "# unknown-token entries. model is whichever Word2Vec was fit last; when the\n",
    "# notebook runs top to bottom that is the word-level model.\n",
    "word2idx = {\"_PAD\": 0, \"_UNK\": 1}\n",
    "\n",
    "# One row per vocabulary entry plus the two reserved rows. Row 0 (_PAD) is\n",
    "# never written below, so it stays all zeros.\n",
    "vocab_size = len(model.wv.vocab)\n",
    "embedding_matrix = np.zeros((vocab_size + 2, model.vector_size))\n",
    "\n",
    "# _UNK gets a random vector shifted so that its mean is zero.\n",
    "unk = np.random.random(size=model.vector_size)\n",
    "unk -= unk.mean()\n",
    "embedding_matrix[1] = unk\n",
    "\n",
    "# Give every vocabulary token the next free id and copy its trained vector\n",
    "# into the matching row of the matrix.\n",
    "for word in model.wv.vocab:\n",
    "    idx = len(word2idx)\n",
    "    word2idx[word] = idx\n",
    "    embedding_matrix[idx] = model.wv[word]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3097"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(word2idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3097"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(embedding_matrix)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf2_py37",
   "language": "python",
   "name": "tf2_py37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
